# 数据存放目录
import glob
import os
from keras import Model, optimizers, Input
from keras.applications import InceptionV3, ResNet50
# from keras.applications.inception_v3 import preprocess_input
from keras.applications.imagenet_utils import preprocess_input
from keras.layers import GlobalAveragePooling2D, Dense
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import matplotlib.pyplot as plt

from datasets import load_dataset
from utils import plot_training


# INPUT_DATA = 'E:/cdn/filtered'
# INPUT_DATA = 'E:/cdn/comment'
# FC_SIZE = 500  # 全连接层的节点个数


# train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig = load_dataset(INPUT_DATA)
# X_train = train_set_x_orig / 255.
# X_test = test_set_x_orig / 255.
# Y_train = train_set_y_orig.T
# Y_test = test_set_y_orig.T
# num = X_train.shape[0]
# print('num=' + str(num))

# print(train_set_x_orig.shape,train_set_y_orig.shape,test_set_x_orig.shape,test_set_y_orig.shape)
#
# index = 6
# plt.imshow(train_set_x_orig[index])
# print("y = " + str(np.squeeze(train_set_y_orig[index])))
# plt.show()






def get_nb_files(directory):
    """Return the number of entries inside the subdirectories of *directory*.

    Walks the tree recursively and, for every directory visited, counts the
    entries of each of its subdirectories.  For the usual
    ``root/class_x/img.jpg`` layout consumed by ``flow_from_directory`` this
    equals the number of images.

    Returns 0 when *directory* does not exist.
    """
    if not os.path.exists(directory):
        return 0
    cnt = 0
    for root, dirs, _files in os.walk(directory):
        for sub in dirs:
            # Join all components with os.path.join instead of mixing a
            # literal "/" into the pattern (keeps separators consistent).
            cnt += len(glob.glob(os.path.join(root, sub, "*")))
    return cnt


# Image generator for the training set: random augmentation on top of the
# ImageNet-style preprocessing expected by the pretrained backbone.
_train_augmentation = {
    "rotation_range": 30,
    "width_shift_range": 0.2,
    "height_shift_range": 0.2,
    "shear_range": 0.2,
    "zoom_range": 0.2,
    "horizontal_flip": True,
}
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input, **_train_augmentation)
# Generator for validation/test images.
# BUG FIX: the original applied the same random augmentation (rotations,
# shifts, flips) to validation data, which makes evaluation metrics noisy
# and non-reproducible.  Validation images should only receive the
# backbone's deterministic preprocessing.
test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)


# train_generator = train_datagen.flow(X_train, Y_train)
# test_generator = test_datagen.flow(X_test, Y_test)


# 添加新层
# Attach a new classification head to a pretrained trunk.
def add_new_last_layer(base_model, nb_classes, fc_size=1024):
    """Append a fresh FC + softmax head to *base_model*.

    Args:
        base_model: pretrained backbone loaded without its top layers.
        nb_classes: number of output classes.
        fc_size: width of the new fully-connected layer.

    Returns:
        A Keras ``Model`` mapping the backbone's input to the new softmax.
    """
    pooled = GlobalAveragePooling2D()(base_model.output)
    hidden = Dense(fc_size, activation='relu')(pooled)  # randomly-initialised FC layer
    outputs = Dense(nb_classes, activation='softmax')(hidden)  # new classifier layer
    return Model(inputs=base_model.input, outputs=outputs)


# 冻上base_model所有层，这样就可以正确获得bottleneck特征
def setup_to_transfer_learn(model, base_model):
    """Freeze every layer of *base_model* and compile *model* for training.

    Only the newly-added head remains trainable, so the pretrained trunk
    acts as a fixed feature extractor (correct bottleneck features).
    """
    for layer in base_model.layers:
        layer.trainable = False
    lr = 0.0001
    # Use the documented optimizer class; lowercase `optimizers.adam` is a
    # legacy alias that is absent from newer Keras releases.
    adam = optimizers.Adam(lr=lr)
    # BUG FIX: the head is a multi-class softmax fed by
    # class_mode='categorical' generators, so the matching loss is
    # categorical_crossentropy.  binary_crossentropy here silently reports
    # inflated, meaningless accuracy.
    model.compile(optimizer=adam, loss='categorical_crossentropy',
                  metrics=['accuracy'])


IM_WIDTH, IM_HEIGHT = 224, 224  # input size expected by the backbone
FC_SIZE = 1024  # width of the new fully-connected layer
train_dir = 'E:/cdn/comment'  # training data, one subdirectory per class
val_dir = 'E:/cdn/comment_val'  # validation data, same layout
# BUG FIX: count only subdirectories.  Plain files inside train_dir (for
# example the 'unopen_model.h5' this script saves there below) would
# otherwise inflate the class count on a subsequent run.
nb_classes = len([d for d in os.listdir(train_dir)
                  if os.path.isdir(os.path.join(train_dir, d))])
nb_epoch = 3
batch_size = 32

nb_train_samples = get_nb_files(train_dir)  # number of training images
nb_val_samples = get_nb_files(val_dir)  # number of validation images

# Batch streams read straight from the per-class folders.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size, class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
    val_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size, class_mode='categorical')

# Build the network: pretrained ResNet50 trunk (no top) + new softmax head.
base_model = ResNet50(weights='imagenet', include_top=False,
                      input_tensor=Input(shape=(IM_WIDTH, IM_HEIGHT, 3)))  # needs the no-top weights downloaded
model = add_new_last_layer(base_model, nb_classes, fc_size=FC_SIZE)  # add the new head
setup_to_transfer_learn(model, base_model)  # freeze all base_model layers

# BUG FIX: steps_per_epoch is the number of BATCHES per epoch, i.e.
# samples // batch_size; the original `nb_train_samples // 5` replays each
# sample roughly batch_size/5 times per epoch.
history = model.fit_generator(
    train_generator,
    epochs=nb_epoch * 5,
    steps_per_epoch=nb_train_samples // batch_size,
    # validation_data=validation_generator,
    # NOTE(review): class_weight="auto" is a Keras 1 feature; Keras 2
    # expects a dict mapping class index -> weight.  Confirm the installed
    # version accepts the string form.
    class_weight="auto")
model_save = os.path.join(train_dir, 'unopen_model.h5')
model.save(model_save)
plot_training(history)
# model = load_model('E:/unopen_model.h5')
# print(Y_test)
# preds = model.evaluate(x=X_test, y=Y_test)
# print()
# print('Loss=' + str(preds[0]))
# print('Test Accuracy = ' + str(preds[1]))
# model.summary()
